{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stochastic Gradient Descent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this section, we are going to introduce the basic principles of stochastic gradient descent."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we use $x=10$ as the initial value and assume $\\eta=0.2$. Using gradient descent to iterate $x$ 10 times, we can see that, eventually, the value of $x$ approaches the optimal solution."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stochastic Gradient Descent (SGD)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In deep learning, the objective function is usually the average of the loss functions for each example in the training data set. We assume that $f_i(\\boldsymbol{x})$ is the loss function of the training data instance with $n$ examples, an index of $i$, and parameter vector of $\\boldsymbol{x}$, then we have the objective function\n",
    "\n",
    "$$f(\\boldsymbol{x}) = \\frac{1}{n} \\sum_{i = 1}^n f_i(\\boldsymbol{x}).$$\n",
    "\n",
    "The gradient of the objective function at $\\boldsymbol{x}$ is computed as\n",
    "\n",
    "$$\\nabla f(\\boldsymbol{x}) = \\frac{1}{n} \\sum_{i = 1}^n \\nabla f_i(\\boldsymbol{x}).$$\n",
    "\n",
    "If gradient descent is used, the computing cost for each independent variable iteration is $\\mathcal{O}(n)$, which grows linearly with $n$. Therefore, when the model training data instance is large, the cost of gradient descent for each iteration will be very high.\n",
    "\n",
    "Stochastic gradient descent (SGD) reduces computational cost at each iteration. At each iteration of stochastic gradient descent, we uniformly sample an index $i\\in{1,\\ldots,n}$ for data instances at random, and compute the gradient $\\nabla f_i(\\boldsymbol{x})$ to update $\\boldsymbol{x}$:\n",
    "\n",
    "$$\\boldsymbol{x} \\leftarrow \\boldsymbol{x} - \\eta \\nabla f_i(\\boldsymbol{x}).$$\n",
    "\n",
    "Here, $\\eta$ is the learning rate. We can see that the computing cost for each iteration drops from $\\mathcal{O}(n)$ of the gradient descent to the constant $\\mathcal{O}(1)$. We should mention that the stochastic gradient $\\nabla f_i(\\boldsymbol{x})$ is the unbiased estimate of gradient $\\nabla f(\\boldsymbol{x})$.\n",
    "\n",
    "$$\\mathbb{E}i \\nabla f_i(\\boldsymbol{x}) = \\frac{1}{n} \\sum{i = 1}^n \\nabla f_i(\\boldsymbol{x}) = \\nabla f(\\boldsymbol{x}).$$\n",
    "\n",
    "This means that, on average, the stochastic gradient is a good estimate of the gradient.\n",
    "\n",
    "Now, we will compare it to gradient descent by adding random noise with a mean of 0 to the gradient to simulate a SGD."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import d2l\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 20, x1 0.102315, x2 -0.103776\n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\r\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
       "<!-- Created with matplotlib (https://matplotlib.org/) -->\r\n",
       "<svg height=\"184.15625pt\" version=\"1.1\" viewBox=\"0 0 248.620313 184.15625\" width=\"248.620313pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
       " <defs>\r\n",
       "  <style type=\"text/css\">\r\n",
       "*{stroke-linecap:butt;stroke-linejoin:round;}\r\n",
       "  </style>\r\n",
       " </defs>\r\n",
       " <g id=\"figure_1\">\r\n",
       "  <g id=\"patch_1\">\r\n",
       "   <path d=\"M 0 184.15625 \r\n",
       "L 248.620313 184.15625 \r\n",
       "L 248.620313 -0 \r\n",
       "L 0 -0 \r\n",
       "z\r\n",
       "\" style=\"fill:none;\"/>\r\n",
       "  </g>\r\n",
       "  <g id=\"axes_1\">\r\n",
       "   <g id=\"patch_2\">\r\n",
       "    <path d=\"M 42.620312 146.6 \r\n",
       "L 237.920313 146.6 \r\n",
       "L 237.920313 10.7 \r\n",
       "L 42.620312 10.7 \r\n",
       "z\r\n",
       "\" style=\"fill:#ffffff;\"/>\r\n",
       "   </g>\r\n",
       "   <g id=\"matplotlib.axis_1\">\r\n",
       "    <g id=\"xtick_1\">\r\n",
       "     <g id=\"line2d_1\">\r\n",
       "      <defs>\r\n",
       "       <path d=\"M 0 0 \r\n",
       "L 0 3.5 \r\n",
       "\" id=\"m64ad326b7a\" style=\"stroke:#000000;stroke-width:0.8;\"/>\r\n",
       "      </defs>\r\n",
       "      <g>\r\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"88.39375\" xlink:href=\"#m64ad326b7a\" y=\"146.6\"/>\r\n",
       "      </g>\r\n",
       "     </g>\r\n",
       "     <g id=\"text_1\">\r\n",
       "      <!-- −4 -->\r\n",
       "      <defs>\r\n",
       "       <path d=\"M 10.59375 35.5 \r\n",
       "L 73.1875 35.5 \r\n",
       "L 73.1875 27.203125 \r\n",
       "L 10.59375 27.203125 \r\n",
       "z\r\n",
       "\" id=\"DejaVuSans-8722\"/>\r\n",
       "       <path d=\"M 37.796875 64.3125 \r\n",
       "L 12.890625 25.390625 \r\n",
       "L 37.796875 25.390625 \r\n",
       "z\r\n",
       "M 35.203125 72.90625 \r\n",
       "L 47.609375 72.90625 \r\n",
       "L 47.609375 25.390625 \r\n",
       "L 58.015625 25.390625 \r\n",
       "L 58.015625 17.1875 \r\n",
       "L 47.609375 17.1875 \r\n",
       "L 47.609375 0 \r\n",
       "L 37.796875 0 \r\n",
       "L 37.796875 17.1875 \r\n",
       "L 4.890625 17.1875 \r\n",
       "L 4.890625 26.703125 \r\n",
       "z\r\n",
       "\" id=\"DejaVuSans-52\"/>\r\n",
       "      </defs>\r\n",
       "      <g transform=\"translate(81.022656 161.198437)scale(0.1 -0.1)\">\r\n",
       "       <use xlink:href=\"#DejaVuSans-8722\"/>\r\n",
       "       <use x=\"83.789062\" xlink:href=\"#DejaVuSans-52\"/>\r\n",
       "      </g>\r\n",
       "     </g>\r\n",
       "    </g>\r\n",
       "    <g id=\"xtick_2\">\r\n",
       "     <g id=\"line2d_2\">\r\n",
       "      <g>\r\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"149.425\" xlink:href=\"#m64ad326b7a\" y=\"146.6\"/>\r\n",
       "      </g>\r\n",
       "     </g>\r\n",
       "     <g id=\"text_2\">\r\n",
       "      <!-- −2 -->\r\n",
       "      <defs>\r\n",
       "       <path d=\"M 19.1875 8.296875 \r\n",
       "L 53.609375 8.296875 \r\n",
       "L 53.609375 0 \r\n",
       "L 7.328125 0 \r\n",
       "L 7.328125 8.296875 \r\n",
       "Q 12.9375 14.109375 22.625 23.890625 \r\n",
       "Q 32.328125 33.6875 34.8125 36.53125 \r\n",
       "Q 39.546875 41.84375 41.421875 45.53125 \r\n",
       "Q 43.3125 49.21875 43.3125 52.78125 \r\n",
       "Q 43.3125 58.59375 39.234375 62.25 \r\n",
       "Q 35.15625 65.921875 28.609375 65.921875 \r\n",
       "Q 23.96875 65.921875 18.8125 64.3125 \r\n",
       "Q 13.671875 62.703125 7.8125 59.421875 \r\n",
       "L 7.8125 69.390625 \r\n",
       "Q 13.765625 71.78125 18.9375 73 \r\n",
       "Q 24.125 74.21875 28.421875 74.21875 \r\n",
       "Q 39.75 74.21875 46.484375 68.546875 \r\n",
       "Q 53.21875 62.890625 53.21875 53.421875 \r\n",
       "Q 53.21875 48.921875 51.53125 44.890625 \r\n",
       "Q 49.859375 40.875 45.40625 35.40625 \r\n",
       "Q 44.1875 33.984375 37.640625 27.21875 \r\n",
       "Q 31.109375 20.453125 19.1875 8.296875 \r\n",
       "z\r\n",
       "\" id=\"DejaVuSans-50\"/>\r\n",
       "      </defs>\r\n",
       "      <g transform=\"translate(142.053906 161.198437)scale(0.1 -0.1)\">\r\n",
       "       <use xlink:href=\"#DejaVuSans-8722\"/>\r\n",
       "       <use x=\"83.789062\" xlink:href=\"#DejaVuSans-50\"/>\r\n",
       "      </g>\r\n",
       "     </g>\r\n",
       "    </g>\r\n",
       "    <g id=\"xtick_3\">\r\n",
       "     <g id=\"line2d_3\">\r\n",
       "      <g>\r\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"210.45625\" xlink:href=\"#m64ad326b7a\" y=\"146.6\"/>\r\n",
       "      </g>\r\n",
       "     </g>\r\n",
       "     <g id=\"text_3\">\r\n",
       "      <!-- 0 -->\r\n",
       "      <defs>\r\n",
       "       <path d=\"M 31.78125 66.40625 \r\n",
       "Q 24.171875 66.40625 20.328125 58.90625 \r\n",
       "Q 16.5 51.421875 16.5 36.375 \r\n",
       "Q 16.5 21.390625 20.328125 13.890625 \r\n",
       "Q 24.171875 6.390625 31.78125 6.390625 \r\n",
       "Q 39.453125 6.390625 43.28125 13.890625 \r\n",
       "Q 47.125 21.390625 47.125 36.375 \r\n",
       "Q 47.125 51.421875 43.28125 58.90625 \r\n",
       "Q 39.453125 66.40625 31.78125 66.40625 \r\n",
       "z\r\n",
       "M 31.78125 74.21875 \r\n",
       "Q 44.046875 74.21875 50.515625 64.515625 \r\n",
       "Q 56.984375 54.828125 56.984375 36.375 \r\n",
       "Q 56.984375 17.96875 50.515625 8.265625 \r\n",
       "Q 44.046875 -1.421875 31.78125 -1.421875 \r\n",
       "Q 19.53125 -1.421875 13.0625 8.265625 \r\n",
       "Q 6.59375 17.96875 6.59375 36.375 \r\n",
       "Q 6.59375 54.828125 13.0625 64.515625 \r\n",
       "Q 19.53125 74.21875 31.78125 74.21875 \r\n",
       "z\r\n",
       "\" id=\"DejaVuSans-48\"/>\r\n",
       "      </defs>\r\n",
       "      <g transform=\"translate(207.275 161.198437)scale(0.1 -0.1)\">\r\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\r\n",
       "      </g>\r\n",
       "     </g>\r\n",
       "    </g>\r\n",
       "    <g id=\"text_4\">\r\n",
       "     <!-- x1 -->\r\n",
       "     <defs>\r\n",
       "      <path d=\"M 54.890625 54.6875 \r\n",
       "L 35.109375 28.078125 \r\n",
       "L 55.90625 0 \r\n",
       "L 45.3125 0 \r\n",
       "L 29.390625 21.484375 \r\n",
       "L 13.484375 0 \r\n",
       "L 2.875 0 \r\n",
       "L 24.125 28.609375 \r\n",
       "L 4.6875 54.6875 \r\n",
       "L 15.28125 54.6875 \r\n",
       "L 29.78125 35.203125 \r\n",
       "L 44.28125 54.6875 \r\n",
       "z\r\n",
       "\" id=\"DejaVuSans-120\"/>\r\n",
       "      <path d=\"M 12.40625 8.296875 \r\n",
       "L 28.515625 8.296875 \r\n",
       "L 28.515625 63.921875 \r\n",
       "L 10.984375 60.40625 \r\n",
       "L 10.984375 69.390625 \r\n",
       "L 28.421875 72.90625 \r\n",
       "L 38.28125 72.90625 \r\n",
       "L 38.28125 8.296875 \r\n",
       "L 54.390625 8.296875 \r\n",
       "L 54.390625 0 \r\n",
       "L 12.40625 0 \r\n",
       "z\r\n",
       "\" id=\"DejaVuSans-49\"/>\r\n",
       "     </defs>\r\n",
       "     <g transform=\"translate(134.129687 174.876562)scale(0.1 -0.1)\">\r\n",
       "      <use xlink:href=\"#DejaVuSans-120\"/>\r\n",
       "      <use x=\"59.179688\" xlink:href=\"#DejaVuSans-49\"/>\r\n",
       "     </g>\r\n",
       "    </g>\r\n",
       "   </g>\r\n",
       "   <g id=\"matplotlib.axis_2\">\r\n",
       "    <g id=\"ytick_1\">\r\n",
       "     <g id=\"line2d_4\">\r\n",
       "      <defs>\r\n",
       "       <path d=\"M 0 0 \r\n",
       "L -3.5 0 \r\n",
       "\" id=\"m730e5a5485\" style=\"stroke:#000000;stroke-width:0.8;\"/>\r\n",
       "      </defs>\r\n",
       "      <g>\r\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"42.620312\" xlink:href=\"#m730e5a5485\" y=\"146.6\"/>\r\n",
       "      </g>\r\n",
       "     </g>\r\n",
       "     <g id=\"text_5\">\r\n",
       "      <!-- −3 -->\r\n",
       "      <defs>\r\n",
       "       <path d=\"M 40.578125 39.3125 \r\n",
       "Q 47.65625 37.796875 51.625 33 \r\n",
       "Q 55.609375 28.21875 55.609375 21.1875 \r\n",
       "Q 55.609375 10.40625 48.1875 4.484375 \r\n",
       "Q 40.765625 -1.421875 27.09375 -1.421875 \r\n",
       "Q 22.515625 -1.421875 17.65625 -0.515625 \r\n",
       "Q 12.796875 0.390625 7.625 2.203125 \r\n",
       "L 7.625 11.71875 \r\n",
       "Q 11.71875 9.328125 16.59375 8.109375 \r\n",
       "Q 21.484375 6.890625 26.8125 6.890625 \r\n",
       "Q 36.078125 6.890625 40.9375 10.546875 \r\n",
       "Q 45.796875 14.203125 45.796875 21.1875 \r\n",
       "Q 45.796875 27.640625 41.28125 31.265625 \r\n",
       "Q 36.765625 34.90625 28.71875 34.90625 \r\n",
       "L 20.21875 34.90625 \r\n",
       "L 20.21875 43.015625 \r\n",
       "L 29.109375 43.015625 \r\n",
       "Q 36.375 43.015625 40.234375 45.921875 \r\n",
       "Q 44.09375 48.828125 44.09375 54.296875 \r\n",
       "Q 44.09375 59.90625 40.109375 62.90625 \r\n",
       "Q 36.140625 65.921875 28.71875 65.921875 \r\n",
       "Q 24.65625 65.921875 20.015625 65.03125 \r\n",
       "Q 15.375 64.15625 9.8125 62.3125 \r\n",
       "L 9.8125 71.09375 \r\n",
       "Q 15.4375 72.65625 20.34375 73.4375 \r\n",
       "Q 25.25 74.21875 29.59375 74.21875 \r\n",
       "Q 40.828125 74.21875 47.359375 69.109375 \r\n",
       "Q 53.90625 64.015625 53.90625 55.328125 \r\n",
       "Q 53.90625 49.265625 50.4375 45.09375 \r\n",
       "Q 46.96875 40.921875 40.578125 39.3125 \r\n",
       "z\r\n",
       "\" id=\"DejaVuSans-51\"/>\r\n",
       "      </defs>\r\n",
       "      <g transform=\"translate(20.878125 150.399219)scale(0.1 -0.1)\">\r\n",
       "       <use xlink:href=\"#DejaVuSans-8722\"/>\r\n",
       "       <use x=\"83.789062\" xlink:href=\"#DejaVuSans-51\"/>\r\n",
       "      </g>\r\n",
       "     </g>\r\n",
       "    </g>\r\n",
       "    <g id=\"ytick_2\">\r\n",
       "     <g id=\"line2d_5\">\r\n",
       "      <g>\r\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"42.620312\" xlink:href=\"#m730e5a5485\" y=\"113.453659\"/>\r\n",
       "      </g>\r\n",
       "     </g>\r\n",
       "     <g id=\"text_6\">\r\n",
       "      <!-- −2 -->\r\n",
       "      <g transform=\"translate(20.878125 117.252877)scale(0.1 -0.1)\">\r\n",
       "       <use xlink:href=\"#DejaVuSans-8722\"/>\r\n",
       "       <use x=\"83.789062\" xlink:href=\"#DejaVuSans-50\"/>\r\n",
       "      </g>\r\n",
       "     </g>\r\n",
       "    </g>\r\n",
       "    <g id=\"ytick_3\">\r\n",
       "     <g id=\"line2d_6\">\r\n",
       "      <g>\r\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"42.620312\" xlink:href=\"#m730e5a5485\" y=\"80.307317\"/>\r\n",
       "      </g>\r\n",
       "     </g>\r\n",
       "     <g id=\"text_7\">\r\n",
       "      <!-- −1 -->\r\n",
       "      <g transform=\"translate(20.878125 84.106536)scale(0.1 -0.1)\">\r\n",
       "       <use xlink:href=\"#DejaVuSans-8722\"/>\r\n",
       "       <use x=\"83.789062\" xlink:href=\"#DejaVuSans-49\"/>\r\n",
       "      </g>\r\n",
       "     </g>\r\n",
       "    </g>\r\n",
       "    <g id=\"ytick_4\">\r\n",
       "     <g id=\"line2d_7\">\r\n",
       "      <g>\r\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"42.620312\" xlink:href=\"#m730e5a5485\" y=\"47.160976\"/>\r\n",
       "      </g>\r\n",
       "     </g>\r\n",
       "     <g id=\"text_8\">\r\n",
       "      <!-- 0 -->\r\n",
       "      <g transform=\"translate(29.257812 50.960194)scale(0.1 -0.1)\">\r\n",
       "       <use xlink:href=\"#DejaVuSans-48\"/>\r\n",
       "      </g>\r\n",
       "     </g>\r\n",
       "    </g>\r\n",
       "    <g id=\"ytick_5\">\r\n",
       "     <g id=\"line2d_8\">\r\n",
       "      <g>\r\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"42.620312\" xlink:href=\"#m730e5a5485\" y=\"14.014634\"/>\r\n",
       "      </g>\r\n",
       "     </g>\r\n",
       "     <g id=\"text_9\">\r\n",
       "      <!-- 1 -->\r\n",
       "      <g transform=\"translate(29.257812 17.813853)scale(0.1 -0.1)\">\r\n",
       "       <use xlink:href=\"#DejaVuSans-49\"/>\r\n",
       "      </g>\r\n",
       "     </g>\r\n",
       "    </g>\r\n",
       "    <g id=\"text_10\">\r\n",
       "     <!-- x2 -->\r\n",
       "     <g transform=\"translate(14.798437 84.790625)rotate(-90)scale(0.1 -0.1)\">\r\n",
       "      <use xlink:href=\"#DejaVuSans-120\"/>\r\n",
       "      <use x=\"59.179688\" xlink:href=\"#DejaVuSans-50\"/>\r\n",
       "     </g>\r\n",
       "    </g>\r\n",
       "   </g>\r\n",
       "   <g id=\"LineCollection_1\"/>\r\n",
       "   <g id=\"LineCollection_2\">\r\n",
       "    <path clip-path=\"url(#p3f04cb47da)\" d=\"M 237.920313 110.004647 \r\n",
       "L 237.381801 110.139024 \r\n",
       "L 234.86875 110.733959 \r\n",
       "L 231.817188 111.371388 \r\n",
       "L 228.765625 111.923827 \r\n",
       "L 225.714063 112.391276 \r\n",
       "L 222.6625 112.773734 \r\n",
       "L 219.610938 113.071201 \r\n",
       "L 216.559375 113.283677 \r\n",
       "L 213.507813 113.411163 \r\n",
       "L 210.45625 113.453659 \r\n",
       "L 210.45625 113.453659 \r\n",
       "L 210.45625 113.453659 \r\n",
       "L 207.404688 113.411163 \r\n",
       "L 204.353125 113.283677 \r\n",
       "L 201.301563 113.071201 \r\n",
       "L 198.25 112.773734 \r\n",
       "L 195.198438 112.391276 \r\n",
       "L 192.146875 111.923827 \r\n",
       "L 189.095313 111.371388 \r\n",
       "L 186.04375 110.733959 \r\n",
       "L 183.530699 110.139024 \r\n",
       "L 182.992188 110.004647 \r\n",
       "L 179.940625 109.153593 \r\n",
       "L 176.889063 108.212953 \r\n",
       "L 173.8375 107.182729 \r\n",
       "L 172.861 106.82439 \r\n",
       "L 170.785938 106.019408 \r\n",
       "L 167.734375 104.740906 \r\n",
       "L 164.998491 103.509756 \r\n",
       "L 164.682813 103.359091 \r\n",
       "L 161.63125 101.802217 \r\n",
       "L 158.672159 100.195122 \r\n",
       "L 158.579688 100.14166 \r\n",
       "L 155.528125 98.270496 \r\n",
       "L 153.383784 96.880488 \r\n",
       "L 152.476563 96.25185 \r\n",
       "L 149.425 94.023045 \r\n",
       "L 148.829573 93.565854 \r\n",
       "L 146.373438 91.540244 \r\n",
       "L 144.88314 90.25122 \r\n",
       "L 143.321875 88.79278 \r\n",
       "L 141.423125 86.936585 \r\n",
       "L 140.270313 85.711612 \r\n",
       "L 138.387434 83.621951 \r\n",
       "L 137.21875 82.201394 \r\n",
       "L 135.724107 80.307317 \r\n",
       "L 134.167188 78.126637 \r\n",
       "L 133.389338 76.992683 \r\n",
       "L 131.354963 73.678049 \r\n",
       "L 131.115625 73.236098 \r\n",
       "L 129.618632 70.363415 \r\n",
       "L 128.121639 67.04878 \r\n",
       "L 128.064063 66.898115 \r\n",
       "L 126.89892 63.734146 \r\n",
       "L 125.900227 60.419512 \r\n",
       "L 125.123466 57.104878 \r\n",
       "L 125.0125 56.441951 \r\n",
       "L 124.584211 53.790244 \r\n",
       "L 124.262993 50.47561 \r\n",
       "L 124.155921 47.160976 \r\n",
       "L 124.262993 43.846341 \r\n",
       "L 124.584211 40.531707 \r\n",
       "L 125.0125 37.88 \r\n",
       "L 125.123466 37.217073 \r\n",
       "L 125.900227 33.902439 \r\n",
       "L 126.89892 30.587805 \r\n",
       "L 128.064063 27.423836 \r\n",
       "L 128.121639 27.273171 \r\n",
       "L 129.618632 23.958537 \r\n",
       "L 131.115625 21.085854 \r\n",
       "L 131.354963 20.643902 \r\n",
       "L 133.389338 17.329268 \r\n",
       "L 134.167188 16.195315 \r\n",
       "L 135.724107 14.014634 \r\n",
       "L 137.21875 12.120557 \r\n",
       "L 138.387434 10.7 \r\n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\r\n",
       "   </g>\r\n",
       "   <g id=\"LineCollection_3\">\r\n",
       "    <path clip-path=\"url(#p3f04cb47da)\" d=\"M 237.920313 138.494213 \r\n",
       "L 234.86875 139.006475 \r\n",
       "L 231.817188 139.45847 \r\n",
       "L 228.765625 139.8502 \r\n",
       "L 227.655966 139.970732 \r\n",
       "L 225.714063 140.174262 \r\n",
       "L 222.6625 140.435944 \r\n",
       "L 219.610938 140.639474 \r\n",
       "L 216.559375 140.784852 \r\n",
       "L 213.507813 140.87208 \r\n",
       "L 210.45625 140.901155 \r\n",
       "L 207.404688 140.87208 \r\n",
       "L 204.353125 140.784852 \r\n",
       "L 201.301563 140.639474 \r\n",
       "L 198.25 140.435944 \r\n",
       "L 195.198438 140.174262 \r\n",
       "L 193.256534 139.970732 \r\n",
       "L 192.146875 139.8502 \r\n",
       "L 189.095313 139.45847 \r\n",
       "L 186.04375 139.006475 \r\n",
       "L 182.992188 138.494213 \r\n",
       "L 179.940625 137.921685 \r\n",
       "L 176.889063 137.288891 \r\n",
       "L 174.102853 136.656098 \r\n",
       "L 173.8375 136.593557 \r\n",
       "L 170.785938 135.811804 \r\n",
       "L 167.734375 134.96751 \r\n",
       "L 164.682813 134.060676 \r\n",
       "L 162.41875 133.341463 \r\n",
       "L 161.63125 133.081492 \r\n",
       "L 158.579688 132.00911 \r\n",
       "L 155.528125 130.871736 \r\n",
       "L 153.383784 130.026829 \r\n",
       "L 152.476563 129.654778 \r\n",
       "L 149.425 128.335689 \r\n",
       "L 146.373438 126.948955 \r\n",
       "L 145.876672 126.712195 \r\n",
       "L 143.321875 125.442761 \r\n",
       "L 140.270313 123.855968 \r\n",
       "L 139.426263 123.397561 \r\n",
       "L 137.21875 122.145366 \r\n",
       "L 134.167188 120.340732 \r\n",
       "L 133.748346 120.082927 \r\n",
       "L 131.115625 118.387067 \r\n",
       "L 128.697406 116.768293 \r\n",
       "L 128.064063 116.323647 \r\n",
       "L 125.0125 114.100416 \r\n",
       "L 124.155921 113.453659 \r\n",
       "L 121.960938 111.711351 \r\n",
       "L 120.047246 110.139024 \r\n",
       "L 118.909375 109.153593 \r\n",
       "L 116.308043 106.82439 \r\n",
       "L 115.857813 106.398223 \r\n",
       "L 112.903125 103.509756 \r\n",
       "L 112.80625 103.409313 \r\n",
       "L 109.801635 100.195122 \r\n",
       "L 109.754688 100.14166 \r\n",
       "L 106.976399 96.880488 \r\n",
       "L 106.703125 96.537595 \r\n",
       "L 104.403397 93.565854 \r\n",
       "L 103.651563 92.522358 \r\n",
       "L 102.061312 90.25122 \r\n",
       "L 100.6 87.997268 \r\n",
       "L 99.931164 86.936585 \r\n",
       "L 98.008262 83.621951 \r\n",
       "L 97.548438 82.753833 \r\n",
       "L 96.287125 80.307317 \r\n",
       "L 94.741 76.992683 \r\n",
       "L 94.496875 76.407747 \r\n",
       "L 93.387216 73.678049 \r\n",
       "L 92.198295 70.363415 \r\n",
       "L 91.445313 67.941182 \r\n",
       "L 91.174921 67.04878 \r\n",
       "L 90.325119 63.734146 \r\n",
       "L 89.629826 60.419512 \r\n",
       "L 89.089043 57.104878 \r\n",
       "L 88.702769 53.790244 \r\n",
       "L 88.471005 50.47561 \r\n",
       "L 88.39375 47.160976 \r\n",
       "L 88.471005 43.846341 \r\n",
       "L 88.702769 40.531707 \r\n",
       "L 89.089043 37.217073 \r\n",
       "L 89.629826 33.902439 \r\n",
       "L 90.325119 30.587805 \r\n",
       "L 91.174921 27.273171 \r\n",
       "L 91.445313 26.380769 \r\n",
       "L 92.198295 23.958537 \r\n",
       "L 93.387216 20.643902 \r\n",
       "L 94.496875 17.914204 \r\n",
       "L 94.741 17.329268 \r\n",
       "L 96.287125 14.014634 \r\n",
       "L 97.548438 11.568118 \r\n",
       "L 98.008262 10.7 \r\n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\r\n",
       "   </g>\r\n",
       "   <g id=\"LineCollection_4\">\r\n",
       "    <path clip-path=\"url(#p3f04cb47da)\" d=\"M 135.724107 146.6 \r\n",
       "L 134.167188 145.897747 \r\n",
       "L 131.115625 144.465151 \r\n",
       "L 128.697406 143.285366 \r\n",
       "L 128.064063 142.965533 \r\n",
       "L 125.0125 141.366367 \r\n",
       "L 122.442763 139.970732 \r\n",
       "L 121.960938 139.699534 \r\n",
       "L 118.909375 137.921685 \r\n",
       "L 116.808299 136.656098 \r\n",
       "L 115.857813 136.061965 \r\n",
       "L 112.80625 134.091947 \r\n",
       "L 111.679519 133.341463 \r\n",
       "L 109.754688 132.00911 \r\n",
       "L 106.976399 130.026829 \r\n",
       "L 106.703125 129.823892 \r\n",
       "L 103.651563 127.490119 \r\n",
       "L 102.663028 126.712195 \r\n",
       "L 100.6 125.019616 \r\n",
       "L 98.677098 123.397561 \r\n",
       "L 97.548438 122.403171 \r\n",
       "L 94.985125 120.082927 \r\n",
       "L 94.496875 119.62042 \r\n",
       "L 91.564205 116.768293 \r\n",
       "L 91.445313 116.647026 \r\n",
       "L 88.39375 113.453659 \r\n",
       "L 88.39375 113.453659 \r\n",
       "L 85.455208 110.139024 \r\n",
       "L 85.342188 110.004647 \r\n",
       "L 82.731815 106.82439 \r\n",
       "L 82.290625 106.256167 \r\n",
       "L 80.208382 103.509756 \r\n",
       "L 79.239063 102.153769 \r\n",
       "L 77.871121 100.195122 \r\n",
       "L 76.1875 97.628954 \r\n",
       "L 75.707479 96.880488 \r\n",
       "L 73.71882 93.565854 \r\n",
       "L 73.135938 92.522358 \r\n",
       "L 71.895192 90.25122 \r\n",
       "L 70.21851 86.936585 \r\n",
       "L 70.084375 86.648356 \r\n",
       "L 68.70625 83.621951 \r\n",
       "L 67.328125 80.307317 \r\n",
       "L 67.032813 79.522272 \r\n",
       "L 66.101283 76.992683 \r\n",
       "L 65.009145 73.678049 \r\n",
       "L 64.045493 70.363415 \r\n",
       "L 63.98125 70.108443 \r\n",
       "L 63.226224 67.04878 \r\n",
       "L 62.534117 63.734146 \r\n",
       "L 61.967848 60.419512 \r\n",
       "L 61.527416 57.104878 \r\n",
       "L 61.212822 53.790244 \r\n",
       "L 61.024066 50.47561 \r\n",
       "L 60.961147 47.160976 \r\n",
       "L 61.024066 43.846341 \r\n",
       "L 61.212822 40.531707 \r\n",
       "L 61.527416 37.217073 \r\n",
       "L 61.967848 33.902439 \r\n",
       "L 62.534117 30.587805 \r\n",
       "L 63.226224 27.273171 \r\n",
       "L 63.98125 24.213508 \r\n",
       "L 64.045493 23.958537 \r\n",
       "L 65.009145 20.643902 \r\n",
       "L 66.101283 17.329268 \r\n",
       "L 67.032813 14.799679 \r\n",
       "L 67.328125 14.014634 \r\n",
       "L 68.70625 10.7 \r\n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\r\n",
       "   </g>\r\n",
       "   <g id=\"LineCollection_5\">\r\n",
       "    <path clip-path=\"url(#p3f04cb47da)\" d=\"M 96.287125 146.6 \r\n",
       "L 94.496875 145.364035 \r\n",
       "L 91.564205 143.285366 \r\n",
       "L 91.445313 143.198139 \r\n",
       "L 88.39375 140.901155 \r\n",
       "L 87.188194 139.970732 \r\n",
       "L 85.342188 138.494213 \r\n",
       "L 83.099473 136.656098 \r\n",
       "L 82.290625 135.968155 \r\n",
       "L 79.274963 133.341463 \r\n",
       "L 79.239063 133.308967 \r\n",
       "L 76.1875 130.481779 \r\n",
       "L 75.707479 130.026829 \r\n",
       "L 73.135938 127.490119 \r\n",
       "L 72.364663 126.712195 \r\n",
       "L 70.084375 124.314375 \r\n",
       "L 69.23125 123.397561 \r\n",
       "L 67.032813 120.93 \r\n",
       "L 66.294013 120.082927 \r\n",
       "L 63.98125 117.307884 \r\n",
       "L 63.540818 116.768293 \r\n",
       "L 60.961147 113.453659 \r\n",
       "L 60.929688 113.411163 \r\n",
       "L 58.55625 110.139024 \r\n",
       "L 57.878125 109.153593 \r\n",
       "L 56.307024 106.82439 \r\n",
       "L 54.826563 104.504146 \r\n",
       "L 54.204399 103.509756 \r\n",
       "L 52.249029 100.195122 \r\n",
       "L 51.775 99.339732 \r\n",
       "L 50.438125 96.880488 \r\n",
       "L 48.7525 93.565854 \r\n",
       "L 48.723438 93.504472 \r\n",
       "L 47.211916 90.25122 \r\n",
       "L 45.785952 86.936585 \r\n",
       "L 45.671875 86.648356 \r\n",
       "L 44.496044 83.621951 \r\n",
       "L 43.320212 80.307317 \r\n",
       "L 42.620313 78.126637 \r\n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\r\n",
       "    <path clip-path=\"url(#p3f04cb47da)\" d=\"M 42.620313 16.195315 \r\n",
       "L 43.320212 14.014634 \r\n",
       "L 44.496044 10.7 \r\n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\r\n",
       "   </g>\r\n",
       "   <g id=\"LineCollection_6\">\r\n",
       "    <path clip-path=\"url(#p3f04cb47da)\" d=\"M 67.328125 146.6 \r\n",
       "L 67.032813 146.347189 \r\n",
       "L 63.98125 143.678628 \r\n",
       "L 63.540818 143.285366 \r\n",
       "L 60.929688 140.87208 \r\n",
       "L 59.974148 139.970732 \r\n",
       "L 57.878125 137.921685 \r\n",
       "L 56.609158 136.656098 \r\n",
       "L 54.826563 134.81116 \r\n",
       "L 53.434102 133.341463 \r\n",
       "L 51.775 131.521664 \r\n",
       "L 50.438125 130.026829 \r\n",
       "L 48.723438 128.031284 \r\n",
       "L 47.611186 126.712195 \r\n",
       "L 45.671875 124.314375 \r\n",
       "L 44.943979 123.397561 \r\n",
       "L 42.620313 120.340732 \r\n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\r\n",
       "   </g>\r\n",
       "   <g id=\"LineCollection_7\">\r\n",
       "    <path clip-path=\"url(#p3f04cb47da)\" d=\"M 43.320212 146.6 \r\n",
       "L 42.620313 145.897747 \r\n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\r\n",
       "   </g>\r\n",
       "   <g id=\"LineCollection_8\"/>\r\n",
       "   <g id=\"line2d_9\">\r\n",
       "    <path clip-path=\"url(#p3f04cb47da)\" d=\"M 57.878125 113.453659 \r\n",
       "L 90.157908 79.523401 \r\n",
       "L 114.050411 68.888656 \r\n",
       "L 130.683068 55.46782 \r\n",
       "L 140.480338 50.2359 \r\n",
       "L 154.377843 53.295862 \r\n",
       "L 170.402313 45.948944 \r\n",
       "L 180.010091 45.063315 \r\n",
       "L 185.226633 43.505409 \r\n",
       "L 196.436771 41.3817 \r\n",
       "L 204.705688 46.480288 \r\n",
       "L 210.865508 45.376872 \r\n",
       "L 212.602283 43.331288 \r\n",
       "L 208.23462 44.244637 \r\n",
       "L 209.440293 46.156267 \r\n",
       "L 210.742667 46.494632 \r\n",
       "L 211.293586 45.343049 \r\n",
       "L 210.259088 49.136948 \r\n",
       "L 209.448613 47.192572 \r\n",
       "L 209.676633 43.978941 \r\n",
       "L 213.578461 50.600774 \r\n",
       "\" style=\"fill:none;stroke:#ff7f0e;stroke-linecap:square;stroke-width:1.5;\"/>\r\n",
       "    <defs>\r\n",
       "     <path d=\"M 0 3 \r\n",
       "C 0.795609 3 1.55874 2.683901 2.12132 2.12132 \r\n",
       "C 2.683901 1.55874 3 0.795609 3 0 \r\n",
       "C 3 -0.795609 2.683901 -1.55874 2.12132 -2.12132 \r\n",
       "C 1.55874 -2.683901 0.795609 -3 0 -3 \r\n",
       "C -0.795609 -3 -1.55874 -2.683901 -2.12132 -2.12132 \r\n",
       "C -2.683901 -1.55874 -3 -0.795609 -3 0 \r\n",
       "C -3 0.795609 -2.683901 1.55874 -2.12132 2.12132 \r\n",
       "C -1.55874 2.683901 -0.795609 3 0 3 \r\n",
       "z\r\n",
       "\" id=\"m4cc0da0ae5\" style=\"stroke:#ff7f0e;\"/>\r\n",
       "    </defs>\r\n",
       "    <g clip-path=\"url(#p3f04cb47da)\">\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"57.878125\" xlink:href=\"#m4cc0da0ae5\" y=\"113.453659\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"90.157908\" xlink:href=\"#m4cc0da0ae5\" y=\"79.523401\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"114.050411\" xlink:href=\"#m4cc0da0ae5\" y=\"68.888656\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"130.683068\" xlink:href=\"#m4cc0da0ae5\" y=\"55.46782\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"140.480338\" xlink:href=\"#m4cc0da0ae5\" y=\"50.2359\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"154.377843\" xlink:href=\"#m4cc0da0ae5\" y=\"53.295862\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"170.402313\" xlink:href=\"#m4cc0da0ae5\" y=\"45.948944\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"180.010091\" xlink:href=\"#m4cc0da0ae5\" y=\"45.063315\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"185.226633\" xlink:href=\"#m4cc0da0ae5\" y=\"43.505409\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"196.436771\" xlink:href=\"#m4cc0da0ae5\" y=\"41.3817\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"204.705688\" xlink:href=\"#m4cc0da0ae5\" y=\"46.480288\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"210.865508\" xlink:href=\"#m4cc0da0ae5\" y=\"45.376872\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"212.602283\" xlink:href=\"#m4cc0da0ae5\" y=\"43.331288\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"208.23462\" xlink:href=\"#m4cc0da0ae5\" y=\"44.244637\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"209.440293\" xlink:href=\"#m4cc0da0ae5\" y=\"46.156267\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"210.742667\" xlink:href=\"#m4cc0da0ae5\" y=\"46.494632\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"211.293586\" xlink:href=\"#m4cc0da0ae5\" y=\"45.343049\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"210.259088\" xlink:href=\"#m4cc0da0ae5\" y=\"49.136948\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"209.448613\" xlink:href=\"#m4cc0da0ae5\" y=\"47.192572\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"209.676633\" xlink:href=\"#m4cc0da0ae5\" y=\"43.978941\"/>\r\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"213.578461\" xlink:href=\"#m4cc0da0ae5\" y=\"50.600774\"/>\r\n",
       "    </g>\r\n",
       "   </g>\r\n",
       "   <g id=\"patch_3\">\r\n",
       "    <path d=\"M 42.620312 146.6 \r\n",
       "L 42.620312 10.7 \r\n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\r\n",
       "   </g>\r\n",
       "   <g id=\"patch_4\">\r\n",
       "    <path d=\"M 237.920313 146.6 \r\n",
       "L 237.920313 10.7 \r\n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\r\n",
       "   </g>\r\n",
       "   <g id=\"patch_5\">\r\n",
       "    <path d=\"M 42.620313 146.6 \r\n",
       "L 237.920313 146.6 \r\n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\r\n",
       "   </g>\r\n",
       "   <g id=\"patch_6\">\r\n",
       "    <path d=\"M 42.620313 10.7 \r\n",
       "L 237.920313 10.7 \r\n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\r\n",
       "   </g>\r\n",
       "  </g>\r\n",
       " </g>\r\n",
       " <defs>\r\n",
       "  <clipPath id=\"p3f04cb47da\">\r\n",
       "   <rect height=\"135.9\" width=\"195.3\" x=\"42.620312\" y=\"10.7\"/>\r\n",
       "  </clipPath>\r\n",
       " </defs>\r\n",
       "</svg>\r\n"
      ],
      "text/plain": [
       "<Figure size 252x180 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def f(x1, x2): return x1 ** 2 + 2 * x2 ** 2 # objective\n",
    "def gradf(x1, x2): return (2 * x1, 4 * x2) # gradient\n",
    "def sgd(x1, x2, s1, s2): # simulate noisy gradient\n",
    "       (g1, g2) = gradf(x1, x2) # compute gradient\n",
    "       (g1, g2) = (g1 + np.random.normal(0.1), g2 + np.random.normal(0.1))\n",
    "       return (x1 -eta * g1, x2 -eta * g2, 0, 0) # update variables\n",
    "def train_2d(trainer):\n",
    "    x1, x2, s1, s2 = -5, -2, 0, 0\n",
    "    results = [(x1, x2)]\n",
    "    for i in range(20):\n",
    "        x1, x2, s1, s2 = trainer(x1, x2, s1, s2)\n",
    "        results.append((x1, x2))\n",
    "    print('epoch %d, x1 %f, x2 %f' % (i + 1, x1, x2))\n",
    "    return results\n",
    "eta = 0.1\n",
    "d2l.show_trace_2d(f, train_2d(sgd))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, the iterative trajectory of the independent variable in the SGD is more tortuous than in the gradient descent. This is due to the noise added in the experiment, which reduced the accuracy of the simulated stochastic gradient. In practice, such noise usually comes from individual examples in the training data set.\n",
    "\n",
    "## Summary\n",
    "* If we use a more suitable learning rate and update the independent variable in the opposite direction of the gradient, the value of the objective function might be reduced. Gradient descent repeats this update process until a solution that meets the requirements is obtained.\n",
    "* Problems occur when the learning rate is too small or too large. A suitable learning rate is usually found only after multiple experiments.\n",
    "* When there are more examples in the training data set, it costs more to compute each iteration for gradient descent, so SGD is preferred in these cases.\n",
    "\n",
    "## Exercises\n",
    "* Using a different objective function, observe the iterative trajectory of the independent variable in gradient descent and the SGD.\n",
    "* In the experiment for gradient descent in two-dimensional space, try to use different learning rates to observe and analyze the experimental phenomena."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
